Cross-validation

Elizabeth King
Kevin Middleton

Models learn from data

Overfitting:

  • Model learns too much.
  • Can’t predict well.

Underfitting:

  • Model doesn’t learn enough.
  • Can’t predict well.

Resampling methods

Model assessment

  • How good is a model at predicting out-of-sample data?

Model selection

  • What model has the optimal “flexibility”?

Machine learning lingo

  • Error rates: Training, testing (validation)
  • Bias: Underfitting
  • Variance: Overfitting
  • Loss or Cost function: Measure of the prediction error
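In R, a loss function is just a summary of the errors; a minimal sketch with hypothetical observed and predicted vectors:

```r
# Hypothetical observed values and model predictions
y     <- c(3.1, 4.8, 6.2, 7.9)
y_hat <- c(3.0, 5.1, 5.8, 8.4)

mean((y - y_hat)^2)   # squared-error loss (MSE) ~ 0.1275
mean(abs(y - y_hat))  # absolute-error loss (MAE) ~ 0.325
```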

Challenges of low-power designs (e.g., “small” data sets)

  • Big data sets can be split easily:
    • Plenty for training
    • Plenty for (validation and) testing
  • Models trained on small data sets have more error (underfitting)

Dividing data

Cross-validation

  • k-fold (v-fold)
  • Leave-one-out (LOOCV)

General CV approach

  • Split data into training and test sets
  • Use training set to build model
  • Use test set to evaluate model (predict new values)
    • Calculate prediction error
  • Repeat for k folds to get average error

Choice of error term

  • (Root) Mean Squared Error (RMSE / MSE)
    • Can be dominated by extreme values (e.g., outliers)
  • Mean Absolute Error (MAE)
    • Related to the median
  • Correlation
    • Between observed and predicted values
    • Less common
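The outlier sensitivity is easy to see with a hypothetical error vector containing one extreme value:

```r
# Hypothetical prediction errors with a single outlier (15)
err <- c(-1, 2, -2, 1, 15)

sqrt(mean(err^2))  # RMSE ~ 6.86: pulled up strongly by the outlier
mean(abs(err))     # MAE  ~ 4.2:  much less affected
```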

Laborious example of 5-fold CV

set.seed(3462834)
DD <- tibble(x = runif(10, 0, 20),
             y = x * 2 + rnorm(10, 5, 10))

ggplot(DD, aes(x, y)) +
  geom_smooth(formula = y ~ x, method = "lm", se = FALSE,
              color = "lightgoldenrod3",
              linewidth = 2) +
  geom_point(size = 5, color = "mediumorchid4") +
  scale_y_continuous(limits = c(8, 50))

Assign folds

DD <- DD |> 
  mutate(k = sample(rep(1:5, each = 2))) |> 
  arrange(k)
DD
# A tibble: 10 × 3
       x     y     k
   <dbl> <dbl> <int>
 1 13.9  28.9      1
 2  2.30  9.12     1
 3 16.5  41.7      2
 4  9.38 26.5      2
 5 14.5  30.2      3
 6  1.65 22.9      3
 7 11.5  16.2      4
 8 17.9  45.7      4
 9 17.3  42.4      5
10  6.77 46.1      5

Fold 1

DD1 <- DD |> filter(k != 1)
lm1 <- lm(y ~ x, data = DD1)

(Pred1 <- DD |> filter(k == 1) |> 
    mutate(y_hat = predict(lm1, newdata = DD |> filter(k == 1))))
# A tibble: 2 × 4
      x     y     k y_hat
  <dbl> <dbl> <int> <dbl>
1 13.9  28.9      1  35.8
2  2.30  9.12     1  24.9


Pred1 <- Pred1 |> 
  mutate(Error = y - y_hat,
         Squared_error = Error ^ 2,
         Absolute_error = abs(Error))
Pred1
# A tibble: 2 × 7
      x     y     k y_hat  Error Squared_error Absolute_error
  <dbl> <dbl> <int> <dbl>  <dbl>         <dbl>          <dbl>
1 13.9  28.9      1  35.8  -6.91          47.7           6.91
2  2.30  9.12     1  24.9 -15.8          248.           15.8 

Function to do all folds

CV_fun <- function(k_fold, DD) {
  DDk <- DD |> filter(k != k_fold)
  lmk <- lm(y ~ x, data = DDk)
  
  DD |> filter(k == k_fold) |> 
    mutate(y_hat = predict(lmk, newdata = DD |> filter(k == k_fold)),
           Error = y - y_hat,
           Squared_error = Error ^ 2,
           Absolute_error = abs(Error))
}

CV_fun(1, DD)
# A tibble: 2 × 7
      x     y     k y_hat  Error Squared_error Absolute_error
  <dbl> <dbl> <int> <dbl>  <dbl>         <dbl>          <dbl>
1 13.9  28.9      1  35.8  -6.91          47.7           6.91
2  2.30  9.12     1  24.9 -15.8          248.           15.8 
Pred1
# A tibble: 2 × 7
      x     y     k y_hat  Error Squared_error Absolute_error
  <dbl> <dbl> <int> <dbl>  <dbl>         <dbl>          <dbl>
1 13.9  28.9      1  35.8  -6.91          47.7           6.91
2  2.30  9.12     1  24.9 -15.8          248.           15.8 

5-fold CV

# future_map() is from the furrr package (purrr::map() works the same here)
CV5 <- future_map(.x = 1:5,
                  .f = CV_fun,
                  DD = DD) |> 
  list_rbind()
CV5
# A tibble: 10 × 7
       x     y     k y_hat  Error Squared_error Absolute_error
   <dbl> <dbl> <int> <dbl>  <dbl>         <dbl>          <dbl>
 1 13.9  28.9      1  35.8  -6.91         47.7            6.91
 2  2.30  9.12     1  24.9 -15.8         248.            15.8 
 3 16.5  41.7      2  37.3   4.44         19.7            4.44
 4  9.38 26.5      2  28.5  -1.99          3.95           1.99
 5 14.5  30.2      3  36.2  -5.98         35.8            5.98
 6  1.65 22.9      3  15.7   7.23         52.3            7.23
 7 11.5  16.2      4  32.5 -16.3         265.            16.3 
 8 17.9  45.7      4  40.4   5.33         28.4            5.33
 9 17.3  42.4      5  37.6   4.76         22.7            4.76
10  6.77 46.1      5  21.1  25.0         627.            25.0 

5-fold CV

mean(CV5$Squared_error)
[1] 135.0853
sqrt(mean(CV5$Squared_error))
[1] 11.62262
mean(CV5$Absolute_error)
[1] 9.371999

Comparing models

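`CV_fun_0` below is the intercept-only analog of `CV_fun`. Its definition is not shown here; a minimal version, assuming the same `DD` tibble, would be:

```r
# Intercept-only counterpart of CV_fun: y ~ 1 predicts mean(y) of the
# training folds for every held-out observation
CV_fun_0 <- function(k_fold, DD) {
  DDk <- DD |> filter(k != k_fold)
  lmk <- lm(y ~ 1, data = DDk)

  DD |> filter(k == k_fold) |>
    mutate(y_hat = predict(lmk, newdata = DD |> filter(k == k_fold)),
           Error = y - y_hat,
           Squared_error = Error ^ 2,
           Absolute_error = abs(Error))
}
```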
CV5_0 <- future_map(.x = 1:5,
                    .f = CV_fun_0,
                    DD = DD) |> 
  list_rbind()
CV5_0
# A tibble: 10 × 7
       x     y     k y_hat  Error Squared_error Absolute_error
   <dbl> <dbl> <int> <dbl>  <dbl>         <dbl>          <dbl>
 1 13.9  28.9      1  34.0  -5.08         25.8            5.08
 2  2.30  9.12     1  34.0 -24.9         618.            24.9 
 3 16.5  41.7      2  30.2  11.5         133.            11.5 
 4  9.38 26.5      2  30.2  -3.66         13.4            3.66
 5 14.5  30.2      3  32.1  -1.84          3.38           1.84
 6  1.65 22.9      3  32.1  -9.19         84.5            9.19
 7 11.5  16.2      4  31.0 -14.7         218.            14.7 
 8 17.9  45.7      4  31.0  14.7         216.            14.7 
 9 17.3  42.4      5  27.7  14.7         215.            14.7 
10  6.77 46.1      5  27.7  18.5         341.            18.5 

Comparing models

Parameter    OLS model    Intercept Only
MSE          135.09       186.79
RMSE         11.62        13.67
MAE          9.37         11.88

RMSE and MAE are interpreted on the “natural scale” of the data

Leave-one-out cross-validation (LOOCV)

  • Each model is fit to n − 1 observations (all points minus one)
  • Should underfit less
    • Uses the maximal number of observations while leaving one out

LOOCV on our sample data

LOOCV_fun <- function(ii, DD) {
  DDk <- DD |> slice(-ii)
  lmk <- lm(y ~ x, data = DDk)
  
  DD |> slice(ii) |> 
    mutate(y_hat = predict(lmk, newdata = DD |> slice(ii)),
           Error = y - y_hat,
           Squared_error = Error ^ 2,
           Absolute_error = abs(Error))
}

LOOCV on our sample data

LOOCV <- future_map(.x = seq_len(nrow(DD)),
                    .f = LOOCV_fun,
                    DD = DD) |> 
  list_rbind()
LOOCV
# A tibble: 10 × 7
       x     y     k y_hat  Error Squared_error Absolute_error
   <dbl> <dbl> <int> <dbl>  <dbl>         <dbl>          <dbl>
 1 13.9  28.9      1  35.4  -6.46         41.7            6.46
 2  2.30  9.12     1  24.6 -15.5         240.            15.5 
 3 16.5  41.7      2  37.1   4.61         21.3            4.61
 4  9.38 26.5      2  28.9  -2.34          5.47           2.34
 5 14.5  30.2      3  36.2  -5.99         35.9            5.99
 6  1.65 22.9      3  15.7   7.24         52.4            7.24
 7 11.5  16.2      4  33.2 -16.9         286.            16.9 
 8 17.9  45.7      4  38.1   7.64         58.3            7.64
 9 17.3  42.4      5  38.1   4.23         17.9            4.23
10  6.77 46.1      5  21.2  24.9         622.            24.9 

Model comparison with LOOCV

Parameter    OLS model    Intercept Only
MSE          138.12       182.21
RMSE         11.75        13.50
MAE          9.59         11.55

5-fold

Parameter    OLS model    Intercept Only
MSE          135.09       186.79
RMSE         11.62        13.67
MAE          9.37         11.88

LOOCV

Parameter    OLS model    Intercept Only
MSE          138.12       182.21
RMSE         11.75        13.50
MAE          9.59         11.55

k-Fold or LOOCV?

  • LOOCV: nearly unbiased estimate of test error, but the n training sets overlap almost completely (higher variance) and it requires n model fits
  • k-fold (k = 5 or 10): slightly more bias, lower variance, only k model fits

LOOCV shortcut

Leverage:

\[h_i = \frac{1}{n} + \frac{(x_i - \bar{x})^2}{\sum^{n}_{i'=1}(x_{i'} - \bar{x})^2}\]

\[CV_{(n)} = \frac{1}{n} \sum^{n}_{i=1} \left( \frac{y_i - \hat{y_i}}{1 - h_i} \right) ^2\]
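For least-squares fits this shortcut needs only a single model fit: `hatvalues()` returns the leverages \(h_i\). A sketch, assuming the `DD` data from the earlier example:

```r
# One fit on the full data; hatvalues() extracts the leverages h_i
fm_all <- lm(y ~ x, data = DD)
h <- hatvalues(fm_all)

# LOOCV MSE from scaled residuals -- no refitting required
CV_n <- mean((residuals(fm_all) / (1 - h))^2)
CV_n  # matches mean(LOOCV$Squared_error) from the explicit loop
```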

CV Logistic regression

0-1 error term

  • Correct or incorrect classification
  • Percent correct
  • Confusion matrix
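A minimal sketch of 0-1 error with hypothetical predicted probabilities and a 0.5 cutoff:

```r
# Hypothetical observed classes and predicted probabilities
obs  <- factor(c(0, 0, 1, 1, 1, 0))
prob <- c(0.2, 0.6, 0.8, 0.4, 0.9, 0.1)
pred <- factor(as.integer(prob > 0.5), levels = levels(obs))

table(predicted = pred, observed = obs)  # confusion matrix
mean(pred == obs)                        # proportion correct
```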

Presence of sole

Presence/absence of sole in the Tagus estuary of Portugal.

tibble [65 × 13] (S3: tbl_df/tbl/data.frame)
 $ Sample       : num [1:65] 1 2 3 4 5 6 7 8 9 10 ...
 $ season       : num [1:65] 1 1 1 1 1 1 1 1 1 1 ...
 $ month        : num [1:65] 5 5 5 5 5 5 5 5 5 5 ...
 $ area         : num [1:65] 2 2 2 4 4 4 3 3 3 1 ...
 $ depth        : num [1:65] 3 2.6 2.6 2.1 3.2 3.5 1.6 1.7 1.8 4.5 ...
 $ temperature  : num [1:65] 20 18 19 20 20 20 19 17 19 21 ...
 $ salinity     : num [1:65] 30 29 30 29 30 32 29 28 29 12 ...
 $ transparency : num [1:65] 15 15 15 15 15 7 15 10 10 35 ...
 $ gravel       : num [1:65] 3.74 1.94 2.88 11.06 9.87 ...
 $ large_sand   : num [1:65] 13.15 4.99 8.98 11.96 28.6 ...
 $ med_fine_sand: num [1:65] 11.93 5.43 16.85 21.95 19.49 ...
 $ mud          : num [1:65] 71.2 87.6 71.3 55 42 ...
 $ Solea_solea  : Factor w/ 2 levels "0","1": 1 1 2 1 1 1 2 2 1 2 ...

DFA & Cross Validation

# lda() is from the MASS package; CV = TRUE requests leave-one-out CV
fm <- lda(Solea_solea ~ salinity, data = MM, CV = TRUE)

table(list(predicted=fm$class, observed=MM$Solea_solea))
         observed
predicted  0  1
        0 34 11
        1  5 15
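Overall accuracy is the diagonal of the confusion matrix over the total count:

```r
tab <- table(list(predicted = fm$class, observed = MM$Solea_solea))
sum(diag(tab)) / sum(tab)   # (34 + 15) / 65, about 0.754 correct
```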

More Resources

James, G., D. Witten, T. Hastie, and R. Tibshirani. 2013. An Introduction to Statistical Learning. Springer, New York.